import os
import json
import inspect
import argparse
import pandas as pd
from tqdm.auto import tqdm
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

from _models.model import get_embedding_func_batched
from _datasets.data import DatasetConfig
from utils.transform_utils import *
from utils.string_utils import *
from utils.metrics import *


class RobustnessExperimentConfig:
    """Robustness experiment for text-similarity metrics on one dataset/model.

    Each row of ``similarity_data`` holds an ``original`` text, a ``summary``
    and perturbed copies of the original (``original_<variant>`` columns).
    Every similarity metric compares the original with each of the other texts
    and is then checked against three ranking criteria:

    * ``summary_over_semantic`` - the summary scores higher than every
      "semantic" variant (``negated``, ``sentence_shuffled``, ``word_shuffled``);
    * ``superficial_over_summary`` - every "superficial" variant (``pruned``,
      ``random_upper``, ``numerized``) scores higher than the summary;
    * ``superficial_over_semantic`` - every superficial variant scores higher
      than every semantic variant.

    A linear ensemble of all metrics is evaluated the same way over repeated
    random train/test splits. Outputs are written under
    ``data/<model_name with '/' replaced by '_'>``.
    """

    # Variants whose similarity to the original is expected to rank below the summary.
    _SEMANTIC_VARIANTS = ("negated", "sentence_shuffled", "word_shuffled")
    # Variants whose similarity to the original is expected to rank above the summary.
    _SUPERFICIAL_VARIANTS = ("pruned", "random_upper", "numerized")
    # Ranking criteria, in the order they are stored as columns / result keys.
    _CRITERIA = (
        "summary_over_semantic",
        "superficial_over_summary",
        "superficial_over_semantic",
    )

    def __init__(
        self,
        dataset_name: str,
        num_examples: int,
        model_name: str = "BAAI/bge-small-en-v1.5",
        max_length: int = 8192,
    ):
        """Load the dataset and embedding function and create the output directory.

        Args:
            dataset_name: dataset name forwarded to ``DatasetConfig``.
            num_examples: example count forwarded to ``DatasetConfig``.
            model_name: identifier passed to ``get_embedding_func_batched``;
                with ``/`` replaced by ``_`` it also names the output
                directory ``data/<model_name>``.
            max_length: forwarded to ``DatasetConfig.get_dataset``
                (presumably a truncation length - TODO confirm).
        """
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.dataset_config = DatasetConfig(dataset_name, num_examples)
        # NOTE(review): the meaning of the positional ``True`` is defined by
        # DatasetConfig.get_dataset - confirm there.
        self.dataset = self.dataset_config.get_dataset(True, max_length)
        print(f"Dataset {dataset_name} loaded.")

        self.embedding_func = get_embedding_func_batched(model_name)
        # Pipeline steps append embedding, similarity and criterion columns here.
        self.similarity_data = pd.DataFrame(self.dataset)
        self.results = {}

        # Replace '/' so an id like "org/model" becomes a single directory level.
        self.model_data_path = os.path.join(
            "data", self.model_name.replace("/", "_")
        )
        os.makedirs(self.model_data_path, exist_ok=True)

    def run(self):
        """Execute the whole pipeline and persist its outputs.

        Embeds every text column, computes per-metric similarities, fits and
        evaluates the ensemble, aggregates per-metric results, then writes
        ``similarity_data`` as a pickle and ``results`` as JSON to
        ``<model_data_path>/<dataset name>_robustness.{pkl,json}``.
        """
        self.generate_embeddings(
            embedding_func=self.embedding_func,
            model_name=self.model_name,
            use_gpu=True,
        )
        print("Generated embeddings.")

        self.calculate_similarities()
        print("Calculated similarities.")

        self.fit_ensembling()
        print("Fitted ensembling.")

        self.get_results()
        print("Got results.")

        stem = os.path.join(
            self.model_data_path, f"{self.dataset_config.name}_robustness"
        )

        # The similarity data (including list-valued embedding columns) is pickled.
        file_path = f"{stem}.pkl"
        self.similarity_data.to_pickle(file_path)
        print(f"Saved data to {file_path}.")

        # Cast NumPy scalars to plain floats *before* opening the file so a
        # conversion error cannot leave a truncated JSON file behind.
        self.results = {
            metric: {key: float(value) for key, value in scores.items()}
            for metric, scores in self.results.items()
        }
        results_file_path = f"{stem}.json"
        with open(results_file_path, "w", encoding="utf-8") as f:
            json.dump(self.results, f)
        print(f"Saved results to {results_file_path}.")

    def generate_embeddings(self, embedding_func, **kwargs):
        """Embed every column of ``similarity_data`` present when the call starts.

        For each such column ``c``, a column ``embeddings_c`` is added with one
        embedding (as a list) per row. Rows whose ``c`` value is missing get
        NaN instead of shifting the remaining embeddings out of alignment.

        Args:
            embedding_func: batched embedder, called as
                ``embedding_func(prompts=[...], pbar=False, **kwargs)``; it may
                return a list or an object with ``.tolist()``.
            **kwargs: forwarded to ``embedding_func``. If the source code of
                ``embedding_func`` does not contain "huggingface",
                ``model_name`` is renamed to ``model`` and ``use_gpu`` is
                dropped (both keys are then required).
        """
        # Non-HuggingFace backends take ``model`` instead of ``model_name`` and no GPU flag.
        if "huggingface" not in inspect.getsource(embedding_func):
            kwargs["model"] = kwargs.pop("model_name")
            kwargs.pop("use_gpu")

        # Snapshot the columns so the embedding columns added below are not embedded too.
        for column in tqdm(
            list(self.similarity_data.columns), desc="Generating embeddings"
        ):
            texts = self.similarity_data[column].dropna()
            embeds = embedding_func(prompts=texts.tolist(), pbar=False, **kwargs)
            if not isinstance(embeds, list):
                embeds = embeds.tolist()
            # Index-aligned object Series: each cell holds one embedding, and
            # rows dropped by ``dropna`` become NaN instead of a length mismatch.
            self.similarity_data[f"embeddings_{column}"] = pd.Series(
                embeds, index=texts.index, dtype=object
            )

    def calculate_similarities(self):
        """Score every similarity metric on each (original, other text) pair.

        For each metric function this adds ``<metric>_original_to_<variant>``
        columns (for ``summary`` and every perturbation variant) plus one
        boolean column per ranking criterion, ``<metric>_<criterion>``.
        ``cosine_similarity`` is applied to the ``embeddings_*`` columns; all
        other metrics are applied to the raw text columns.
        """
        data = self.similarity_data
        metric_fns = [
            levenshtein_ratio,
            jaccard_similarity,
            cosine_similarity,
            bm25_score,
            rouge_score,
        ]
        variants = self._SEMANTIC_VARIANTS + self._SUPERFICIAL_VARIANTS
        for fn in tqdm(metric_fns, desc="Calculating similarities"):
            metric = f'{fn.__name__.split("_")[0]}_similarity'
            prefix = "embeddings_" if fn is cosine_similarity else ""
            original = data[f"{prefix}original"]

            # Compute every score first, then store them in the original column order.
            scores = {"summary": fn(original, data[f"{prefix}summary"])}
            for variant in variants:
                scores[variant] = fn(original, data[f"{prefix}original_{variant}"])

            for variant, values in scores.items():
                data[f"{metric}_original_to_{variant}"] = values
            for criterion, passed in self._ranking_criteria(scores).items():
                data[f"{metric}_{criterion}"] = passed

    @classmethod
    def _ranking_criteria(cls, scores):
        """Evaluate the three ranking criteria on per-text similarity scores.

        Args:
            scores: maps ``"summary"`` and every variant name to array-like
                similarities between the original and that text.

        Returns:
            Dict keyed by the names in ``_CRITERIA``, each an element-wise
            boolean result of the corresponding comparisons.
        """
        summary = [scores["summary"]]
        semantic = [scores[v] for v in cls._SEMANTIC_VARIANTS]
        superficial = [scores[v] for v in cls._SUPERFICIAL_VARIANTS]
        return {
            "summary_over_semantic": cls._all_greater(summary, semantic),
            "superficial_over_summary": cls._all_greater(superficial, summary),
            "superficial_over_semantic": cls._all_greater(superficial, semantic),
        }

    @staticmethod
    def _all_greater(higher, lower):
        """Element-wise AND of ``hi > lo`` over every pair, ``higher`` outermost."""
        comparisons = [hi > lo for hi in higher for lo in lower]
        result = comparisons[0]
        for comparison in comparisons[1:]:
            result = result & comparison
        return result

    @staticmethod
    def _with_overall(scores):
        """Add ``"overall"`` (mean of the three criteria) to ``scores`` and return it."""
        scores["overall"] = (
            scores["summary_over_semantic"]
            + scores["superficial_over_summary"]
            + scores["superficial_over_semantic"]
        ) / 3
        return scores

    def get_transform_similarity(self, transform, score=None):
        """Return all metric columns for one transform, renamed to the metric.

        Selects every column whose name contains ``transform`` (e.g.
        ``"original_to_summary"``) and renames each to its first two
        ``_``-separated parts (e.g. ``cosine_similarity``).

        Args:
            transform: substring identifying the transform's columns.
            score: if given, a constant ``score`` column (the regression
                target) is appended, aligned with ``similarity_data``'s index.

        Returns:
            A new DataFrame sharing ``similarity_data``'s index.
        """
        columns = [c for c in self.similarity_data.columns if transform in c]
        out = self.similarity_data[columns].rename(
            columns={c: "_".join(c.split("_")[:2]).strip() for c in columns}
        )
        if score is not None:
            # ``assign`` broadcasts on the frame's own index, unlike a separate
            # RangeIndex Series, which misaligns under ``pd.concat(axis=1)``.
            out = out.assign(score=score)
        return out

    def fit_ensembling(self, n_splits: int = 1000, test_size: float = 0.2):
        """Evaluate a linear ensemble of all metrics over repeated random splits.

        Each (row, text) pair becomes a sample whose features are the
        per-metric similarities and whose target is 0.5 for the summary, 0 for
        semantic variants and 1 for superficial variants. For each seed in
        ``range(n_splits)``, rows are split with ``train_test_split``, a
        ``LinearRegression(fit_intercept=False)`` is fitted on the training
        rows, and the ranking criteria are evaluated on its predictions for
        the test rows. Per-split pass rates are averaged into
        ``self.results["ensembled_similarity"]``.

        Args:
            n_splits: number of random splits (default keeps the original 1000).
            test_size: fraction of rows held out in each split.

        Raises:
            ValueError: if ``n_splits`` is less than 1.
        """
        if n_splits < 1:
            raise ValueError(f"n_splits must be >= 1, got {n_splits}")

        targets = {"summary": 0.5}
        targets.update(dict.fromkeys(self._SEMANTIC_VARIANTS, 0))
        targets.update(dict.fromkeys(self._SUPERFICIAL_VARIANTS, 1))

        # Stacked training table: one row per (example, text). Index labels are
        # repeated per text, so ``.loc[labels]`` picks every text of those rows.
        labelled = pd.concat(
            [
                self.get_transform_similarity(f"original_to_{variant}", target)
                for variant, target in targets.items()
            ],
            axis=0,
        )
        # Feature frames do not change between splits, so build them once.
        features = {
            variant: self.get_transform_similarity(f"original_to_{variant}")
            for variant in targets
        }
        # NOTE(review): assumes similarity_data's index labels are unique.
        row_labels = self.similarity_data.index.to_numpy()

        per_split = {criterion: [] for criterion in self._CRITERIA}
        for seed in tqdm(range(n_splits), desc="Ensembling"):
            train_idx, test_idx = train_test_split(
                row_labels, test_size=test_size, random_state=seed
            )
            train = labelled.loc[train_idx]

            ensemble = LinearRegression(fit_intercept=False)
            ensemble.fit(train.drop("score", axis=1), train["score"])

            predictions = {
                variant: ensemble.predict(frame.loc[test_idx])
                for variant, frame in features.items()
            }
            for criterion, passed in self._ranking_criteria(predictions).items():
                per_split[criterion].append(passed.mean())

        self.results["ensembled_similarity"] = self._with_overall(
            {criterion: np.mean(rates) for criterion, rates in per_split.items()}
        )

    def get_results(self):
        """Aggregate per-metric pass rates into ``self.results`` and return it.

        For every metric name, stores the mean of its three criterion columns
        (written by ``calculate_similarities``) and their average
        (``"overall"``) under ``self.results["<name>_similarity"]``.
        """
        # NOTE(review): ``metrics`` is not defined in this class; it presumably
        # comes from one of the wildcard imports (likely utils.metrics). Its
        # entries must match the prefixes derived in calculate_similarities.
        for name in metrics:
            metric = f"{name}_similarity"
            self.results[metric] = self._with_overall(
                {
                    criterion: self.similarity_data[f"{metric}_{criterion}"].mean()
                    for criterion in self._CRITERIA
                }
            )

        return self.results


def main(
    dataset_name="scores",
    num_examples=5,
    model_name="embed-english-v3.0",
    max_length=8192,
):
    """Build a RobustnessExperimentConfig from the given settings and run it.

    Note: the ``dataset_name`` default here ("scores") differs from the
    ``--dataset_name`` default ("scientific_papers") used by the command-line
    entry point of this module.
    """
    experiment = RobustnessExperimentConfig(
        dataset_name=dataset_name,
        num_examples=num_examples,
        model_name=model_name,
        max_length=max_length,
    )
    experiment.run()


if __name__ == "__main__":
    # Command-line entry point: every option maps 1:1 onto a main() parameter.
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset_name", type=str, default="scientific_papers")
    cli.add_argument("--num_examples", type=int, default=5)
    cli.add_argument("--model_name", type=str, default="embed-english-v3.0")
    cli.add_argument("--max_length", type=int, default=8192)
    options = cli.parse_args()

    main(
        dataset_name=options.dataset_name,
        num_examples=options.num_examples,
        model_name=options.model_name,
        max_length=options.max_length,
    )
